package com.fly.socket.websocket;

import java.io.IOException;
import java.util.Date;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import com.fly.socket.po.Message;
import com.fly.socket.service.LoginService;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;

/**
 * WebSocket handler for text messages (text: TextWebSocketHandler; binary streams would use
 * BinaryWebSocketHandler).
 *
 * <p>Keeps one registered session per user id in {@link #USER_SOCKET_SESSION_MAP}, relays each
 * incoming JSON {@link Message} to the user named in its {@code to} field, and broadcasts a
 * notification to every registered session when a user is registered or removed.
 */
@Component
public class MyWebSocketHandler extends TextWebSocketHandler
{
    private static final Logger LOGGER = LoggerFactory.getLogger(MyWebSocketHandler.class);
    
    /** Maps a user id to the WebSocket session registered for that user. */
    public static final Map<Long, WebSocketSession> USER_SOCKET_SESSION_MAP = new ConcurrentHashMap<>();
    
    /** {@code from} value that marks a "user came online" notification. */
    private static final long ONLINE_NOTICE_FROM = 0L;
    
    /** {@code from} value that marks a "user went offline" notification. */
    private static final long OFFLINE_NOTICE_FROM = -2L;
    
    /** Parses incoming payloads; a default Gson, so inbound date parsing is unchanged. */
    private static final Gson INBOUND_GSON = new Gson();
    
    /** Serializes outgoing messages with the date format the clients receive. */
    private static final Gson OUTBOUND_GSON = new GsonBuilder().setDateFormat("yyyy-MM-dd HH:mm:ss").create();
    
    /**
     * Shared pool for broadcast sends, so a broadcast reuses idle threads instead of creating one
     * raw thread per recipient. Threads are daemons so they never keep the JVM alive.
     */
    private static final ExecutorService BROADCAST_EXECUTOR = Executors.newCachedThreadPool(task -> {
        Thread worker = new Thread(task, "websocket-broadcast");
        worker.setDaemon(true);
        return worker;
    });
    
    @Autowired
    LoginService loginservice;
    
    /**
     * Registers the session under the user id stored in its {@code uid} attribute and broadcasts an
     * online notification carrying the user's name.
     *
     * <p>A session without a {@code uid} attribute is closed. If the user already has a registered
     * session, the existing one is kept and the new session is left unregistered.
     *
     * @param session the newly established session
     * @throws Exception if the name lookup, the close or the broadcast fails
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session)
        throws Exception
    {
        Long uid = (Long)session.getAttributes().get("uid");
        if (uid == null)
        {
            // ConcurrentHashMap rejects null keys, so an unidentified session cannot be registered.
            LOGGER.warn("Websocket:{} has no uid attribute, closing", session.getId());
            session.close(CloseStatus.POLICY_VIOLATION);
            return;
        }
        String username = loginservice.getnamebyid(uid);
        // putIfAbsent makes "check, then register" atomic: two simultaneous connections of the same
        // user cannot both register and both announce.
        if (USER_SOCKET_SESSION_MAP.putIfAbsent(uid, session) != null)
        {
            LOGGER.info("User {} already has a registered session, Websocket:{} not registered", uid, session.getId());
            return;
        }
        this.broadcast(noticeFor(ONLINE_NOTICE_FROM, username));
    }
    
    /**
     * Handles a message sent by a client: parses the payload as a JSON {@link Message}, stamps it
     * with the current date and forwards it to the user named in its {@code to} field.
     *
     * <p>Empty payloads and payloads that parse to nothing are ignored.
     *
     * @param session the session the message arrived on
     * @param message the received message
     * @throws Exception if parsing or forwarding fails
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message)
        throws Exception
    {
        if (message.getPayloadLength() == 0)
        {
            return;
        }
        Message msg = INBOUND_GSON.fromJson(message.getPayload().toString(), Message.class);
        if (msg == null)
        {
            // Gson returns null for a JSON "null" or blank document.
            LOGGER.warn("Websocket:{} sent a payload without a message object, ignored", session.getId());
            return;
        }
        // NOTE(review): the sender id is taken from the client payload as-is; confirm whether it
        // should be checked against the session's uid attribute.
        msg.setDate(new Date());
        sendMessageToUser(msg.getTo(), toTextMessage(msg));
    }
    
    /**
     * Handles a transport error: closes the session if it is still open, then unregisters it and
     * broadcasts an offline notification.
     *
     * @param session the session on which the error occurred
     * @param exception the transport error
     * @throws Exception if the name lookup or the broadcast fails
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception)
        throws Exception
    {
        LOGGER.warn("Websocket:{} transport error", session.getId(), exception);
        if (session.isOpen())
        {
            try
            {
                session.close();
            }
            catch (IOException e)
            {
                // A failed close must not skip the cleanup below, or the entry stays in the map.
                LOGGER.warn("Websocket:{} could not be closed cleanly", session.getId(), e);
            }
        }
        unregisterAndAnnounce(session);
    }
    
    /**
     * After the connection is closed: unregisters the session and broadcasts an offline
     * notification.
     *
     * @param session the closed session
     * @param closeStatus the status the connection was closed with
     * @throws Exception if the name lookup or the broadcast fails
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus)
        throws Exception
    {
        LOGGER.info("Websocket:{} 已经关闭", session.getId());
        unregisterAndAnnounce(session);
    }
    
    @Override
    public boolean supportsPartialMessages()
    {
        return false;
    }
    
    /**
     * Sends a message to every registered session that is open. Each send runs on the shared
     * broadcast pool, so this method does not wait for delivery; send failures are logged.
     *
     * @param message the message to send
     * @throws IOException declared for compatibility with existing callers
     */
    public void broadcast(final TextMessage message)
        throws IOException
    {
        for (final WebSocketSession target : USER_SOCKET_SESSION_MAP.values())
        {
            if (!target.isOpen())
            {
                continue;
            }
            BROADCAST_EXECUTOR.execute(() -> {
                try
                {
                    sendIfOpen(target, message);
                }
                catch (IOException | RuntimeException e)
                {
                    // Task boundary: one failing recipient must not affect the others.
                    LOGGER.error("broadcast error", e);
                }
            });
        }
    }
    
    /**
     * Sends a message to a single user. Nothing is sent when {@code uid} is {@code null}, when the
     * user has no registered session, or when that session is no longer open.
     *
     * @param uid id of the recipient
     * @param message the message to send
     * @throws IOException if sending fails
     */
    public void sendMessageToUser(Long uid, TextMessage message)
        throws IOException
    {
        if (uid == null)
        {
            // ConcurrentHashMap.get(null) would throw; a message without a recipient is dropped.
            return;
        }
        WebSocketSession session = USER_SOCKET_SESSION_MAP.get(uid);
        if (session != null)
        {
            sendIfOpen(session, message);
        }
    }
    
    /**
     * Removes the map entry whose session has the same id as {@code session} and, if this call
     * performed the removal, broadcasts an offline notification with the user's name.
     */
    private void unregisterAndAnnounce(WebSocketSession session)
        throws IOException
    {
        String closedId = session.getId();
        for (Entry<Long, WebSocketSession> entry : USER_SOCKET_SESSION_MAP.entrySet())
        {
            if (!entry.getValue().getId().equals(closedId))
            {
                continue;
            }
            Long uid = entry.getKey();
            // Conditional remove: only the caller that actually removes the entry announces, and a
            // different session registered for the same user in the meantime is left alone.
            if (USER_SOCKET_SESSION_MAP.remove(uid, entry.getValue()))
            {
                LOGGER.info("Socket会话已经移除:用户ID {}", uid);
                this.broadcast(noticeFor(OFFLINE_NOTICE_FROM, loginservice.getnamebyid(uid)));
            }
            break;
        }
    }
    
    /** Builds an online/offline notification: {@code from} is the marker value, the text is the user name. */
    private static TextMessage noticeFor(long from, String username)
    {
        Message notice = new Message();
        notice.setFrom(from);
        notice.setText(username);
        return toTextMessage(notice);
    }
    
    /** Serializes a message to the JSON text frame sent to clients. */
    private static TextMessage toTextMessage(Message msg)
    {
        return new TextMessage(OUTBOUND_GSON.toJson(msg));
    }
    
    /**
     * Sends to a session if it is open. Sends are serialized per session because broadcast workers
     * and handler threads may target the same session at the same time.
     */
    private static void sendIfOpen(WebSocketSession target, TextMessage message)
        throws IOException
    {
        synchronized (target)
        {
            if (target.isOpen())
            {
                target.sendMessage(message);
            }
        }
    }
    
}
